思路并不难的一道题,实现起来有点麻烦

很明显,时间具有可二分性,对于一个时间t,若时间t内军队可以完全覆盖,

那么比t更大的时间一定可以

所以就可以先二分时间,那么如何判断每一个时间是否可行

因为军队是可以同时移动的,并且军队在深度小的结点是比在深度大的结点更优的

所以对于每一个军队,计算出在当前时间限制t内往上跳能跳的深度最小的结点,

如果这个军队可以跳到1号结点,就让它暂时闲置在1号结点的子节点,然后可以枚

举1号结点的所有子节点,计算出以这个子节点为根的子树除开它的根是否已经被

覆盖,对于所有闲置的军队,若它剩余的时间不够它到根节点在返回,就让它直接

驻扎在当前结点。

结合代码理解

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 5e4 + 5;
typedef long long ll;

ll dist[N];
int dt[N], fa[N][25], tot, sn[N], cnt, num;
int head[N], ver[2 * N], edge[2 * N], net[2 * N], idx;
int is[N];
bool use[N];
struct Move
{
int p;
ll t;
bool operator < (const Move a) const
{return t < a.t;}
} a[N], ned[N];

void add(int a, int b, int c)
{
net[++idx] = head[a], ver[idx] = b, edge[idx] = c, head[a] = idx;
}

int dfs2(int u, int f)//返回值为当前结点的子树覆盖的叶子结点数
{
int res = 0;
for (int i = head[u]; i; i = net[i])
{
int v = ver[i];
if (v == f)
continue;
res += dfs2(v, u);
}
if (f == 1)
{
if (res == sn[u])
return 1;
return 0;
}
if (is[u])
return sn[u];
return res;
}

bool check(ll tim)
{
cnt = num = 0;
memset(is, 0, sizeof(is));
memset(use, 0, sizeof(use));
for (int i = 1; i <= tot; i++)
{
int u = dt[i], v = dt[i];
for (int i = 20; i >= 0; i--)
if (dist[u] - dist[fa[v][i]] <= tim && fa[v][i] != 1 && fa[v][i] != 0)
v = fa[v][i];
if (fa[v][0] == 1)
a[++cnt] = Move({v, tim - dist[u]});
is[v]++;
}//所有军队往上跳,并储存闲置军队
for (int i = head[1]; i; i = net[i])
{
int u = ver[i];
use[u] = dfs2(u, 1);
}//判断是否已经被覆盖
for (int i = 1; i <= cnt; i++)
{
if (a[i].t < dist[a[i].p] && !use[a[i].p])
use[a[i].p] = true, a[i].t = 0;
}//若有军队剩余的时间不够在返回,就让它直接驻扎在该结点
for (int i = head[1]; i; i = net[i])
if (!use[ver[i]])
ned[++num] = Move({ver[i], edge[i]});//储存所有还未被覆盖的子树
sort(a + 1, a + 1 + cnt);
sort(ned + 1, ned + 1 + num);
int now = 1;
if (num > cnt)
return 0;
for (int i = 1; i <= num; i++)//贪心匹配
{
while (a[now].t < ned[i].t && now <= cnt)
now++;
if (now > cnt)
return false;
now++;
}
return true;
}

void dfs1(int u, int f, int e)
{
fa[u][0] = f, dist[u] = dist[f] + e;
for (int i = 1; i <= 20; i++)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
int flag = 1;
for (int i = head[u]; i; i = net[i])
{
int v = ver[i];
if (v == f)
continue;
flag = 0;
dfs1(v, u, edge[i]);
sn[u] += sn[v];
}
sn[u] += flag;
}

int main()
{
int n, m;
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
add(u, v, w), add(v, u, w);
}
dfs1(1, 0, 0);
scanf("%d", &m);
for (int i = 1; i <= m; i++)
scanf("%d", &dt[++tot]);
ll l = 0, r = 1e15, ans = -1;
while (l <= r)
{
ll mid = (l + r) >> 1;
if (check(mid))
ans = mid, r = mid - 1;
else
l = mid + 1;
}
printf("%lld", ans);
}